{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This Notebook shows how to do streaming ingestion into TensorFlow using TensorFlow-IO and the Kafka plugin.\n",
    "\n",
    "# Steps:\n",
    "# 1) Produce streaming messages to a Kafka topic\n",
    "# 2) Consume streaming messages with KafkaDataset (a subclass of tf.data.Dataset)\n",
    "# 3) Define and train model using Keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import confluent_kafka as kafka"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. MNIST Kafka producer (run separately from the consumer/training cells below).\n",
    "# Publishes every training image to topic 'xx' and its label to topic 'yy' as raw bytes.\n",
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
    "print(\"train: \", (x_train.shape, y_train.shape))\n",
    "\n",
    "producer = kafka.Producer({'bootstrap.servers': 'localhost:9092'})\n",
    "\n",
    "def produce_with_retry(topic, payload):\n",
    "  # The consumer cell reads only partition 0 ('xx:0' / 'yy:0'), so pin the partition;\n",
    "  # this also keeps images and labels in matching order on multi-partition topics.\n",
    "  while True:\n",
    "    try:\n",
    "      producer.produce(topic, payload, partition=0)\n",
    "      return\n",
    "    except BufferError:\n",
    "      # Local producer queue is full: wait for deliveries to drain, then retry.\n",
    "      producer.poll(1)\n",
    "\n",
    "count = 0\n",
    "for (x, y) in zip(x_train, y_train):\n",
    "  # Serve pending delivery events without blocking.\n",
    "  producer.poll(0)\n",
    "  produce_with_retry('xx', x.tobytes())\n",
    "  produce_with_retry('yy', y.tobytes())\n",
    "  count += 1\n",
    "print(\"count(x, y): \", count)\n",
    "# Block until every queued message has been delivered.\n",
    "producer.flush()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_io.kafka as kafka_io\n",
    "import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. KafkaDataset with map function\n",
    "def func_x(x):\n",
    "  # Turn one raw Kafka message into a (28, 28) float32 image tensor for tf.keras.\n",
    "  pixels = tf.io.decode_raw(x, out_type=tf.uint8)\n",
    "  image = tf.reshape(pixels, [28, 28])\n",
    "  # convert_image_dtype casts to float32 (and rescales the integer pixel values).\n",
    "  return tf.image.convert_image_dtype(image, tf.float32)\n",
    "def func_y(y):\n",
    "  # Turn one raw Kafka message into a scalar (shape ()) uint8 label tensor.\n",
    "  label_bytes = tf.io.decode_raw(y, out_type=tf.uint8)\n",
    "  return tf.reshape(label_bytes, [])\n",
    "# Read partition 0 of each topic until end-of-stream (eof=True), then decode per message.\n",
    "train_images = kafka_io.KafkaDataset(['xx:0'], group='xx', eof=True)\n",
    "train_images = train_images.map(func_x)\n",
    "train_labels = kafka_io.KafkaDataset(['yy:0'], group='yy', eof=True)\n",
    "train_labels = train_labels.map(func_y)\n",
    "\n",
    "# Pair each image with its label positionally and feed them one sample per batch.\n",
    "train_kafka = tf.data.Dataset.zip((train_images, train_labels))\n",
    "train_kafka = train_kafka.batch(1)\n",
    "print(train_kafka)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Keras model: flatten the 28x28 image, one hidden ReLU layer, 10-way softmax output.\n",
    "model = tf.keras.Sequential()\n",
    "model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))\n",
    "model.add(tf.keras.layers.Dense(128, activation=tf.nn.relu))\n",
    "model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))\n",
    "\n",
    "# Labels are integer class ids, hence the sparse categorical cross-entropy loss.\n",
    "model.compile(\n",
    "    optimizer='adam',\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 4. TensorBoard callback for monitoring training.\n",
    "# Each run logs into its own timestamped sub-directory of logs/fit/.\n",
    "log_dir = \"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "tensorboard_callback = tf.keras.callbacks.TensorBoard(\n",
    "    log_dir=log_dir,\n",
    "    histogram_freq=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 5. Train model\n",
    "\n",
    "# This call runs a single epoch of 1000 steps, streaming batches from train_kafka.\n",
    "# NOTE(review): an earlier comment cited \"5 epochs and 12000 steps\" as the default -- confirm before scaling up.\n",
    "model.fit(\n",
    "    train_kafka,\n",
    "    epochs=1,\n",
    "    steps_per_epoch=1000,\n",
    "    callbacks=[tensorboard_callback],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
